import torch

# Lookup table from dtype name strings to the matching torch dtype objects.
# Only the two 16-bit floating-point names listed here are present; any
# other name is simply absent from the mapping.
str_to_torch_dtype = dict(
    bfloat16=torch.bfloat16,
    float16=torch.float16,
)
